{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNIy6uRivivL3VYs0eNmj01",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/Baijiong-Lin/LoRA-Torch/blob/main/examples/linear.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-j81b-ktkWH0",
        "outputId": "ca0a1e95-ae6c-44ac-cee0-5bf75f8fd3b1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/microsoft/LoRA\n",
            "  Cloning https://github.com/microsoft/LoRA to /tmp/pip-req-build-e5ulip2v\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/microsoft/LoRA /tmp/pip-req-build-e5ulip2v\n",
            "  Resolved https://github.com/microsoft/LoRA to commit 998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting git+https://github.com/Baijiong-Lin/LoRA-Torch\n",
            "  Cloning https://github.com/Baijiong-Lin/LoRA-Torch to /tmp/pip-req-build-l221wgno\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/Baijiong-Lin/LoRA-Torch /tmp/pip-req-build-l221wgno\n",
            "  Resolved https://github.com/Baijiong-Lin/LoRA-Torch to commit 4b550558ff1c7ef6bd1009b80d30e52069396f69\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# These two installs are expected to provide the `loralib` and `loratorch`\n",
        "# modules imported in the next cell.\n",
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install git+https://github.com/microsoft/LoRA\n",
        "%pip install git+https://github.com/Baijiong-Lin/LoRA-Torch"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch, loralib, loratorch, copy\n",
        "import torch.nn as nn"
      ],
      "metadata": {
        "id": "tDo6S6QPkiXI"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Seed the RNG so the random initialisation below (and, when the cells run in\n",
        "# order, the random batches drawn later) is the same on every Restart & Run All.\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Same layer configuration in both libraries; the inputs used later have 5 features.\n",
        "model_lib = loralib.Linear(5, 6, r=4, lora_alpha=1)\n",
        "model_torch = loratorch.Linear(5, 6, r=4, lora_alpha=1)\n",
        "\n",
        "# Start both layers from identical values by copying every tensor from the\n",
        "# loratorch layer into the loralib one. The LoRA factors are called\n",
        "# lora_A / lora_B in loralib and w_lora_A / w_lora_B in loratorch.\n",
        "# .detach().clone() yields an independent copy, so later updates to one model\n",
        "# cannot leak into the other through shared storage.\n",
        "model_lib.weight.data = model_torch.weight.detach().clone()\n",
        "model_lib.lora_A.data = model_torch.w_lora_A.detach().clone()\n",
        "model_lib.lora_B.data = model_torch.w_lora_B.detach().clone()\n",
        "model_lib.bias.data = model_torch.bias.detach().clone()"
      ],
      "metadata": {
        "id": "_2ceOeiIk8xp"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Let each library apply its own \"only LoRA is trainable\" policy to its layer.\n",
        "# NOTE(review): presumably this freezes the non-LoRA parameters -- confirm\n",
        "# against each library's documentation.\n",
        "for lora_pkg, layer in ((loralib, model_lib), (loratorch, model_torch)):\n",
        "    lora_pkg.mark_only_lora_as_trainable(layer)"
      ],
      "metadata": {
        "id": "CF_PCLIslcJl"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def sgd_step(model, optimizer, batch):\n",
        "    \"\"\"Run one optimisation step on `model`, using the sum of its outputs as the loss.\n",
        "\n",
        "    The gradients are left on the parameters afterwards so the caller can\n",
        "    inspect them. Returns the loss tensor.\n",
        "    \"\"\"\n",
        "    loss = model(batch).sum()\n",
        "    optimizer.zero_grad()\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    return loss\n",
        "\n",
        "\n",
        "optimizer_lib = torch.optim.SGD(model_lib.parameters(), lr=0.1)\n",
        "optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.1)\n",
        "\n",
        "for _ in range(3):\n",
        "    model_lib.train()\n",
        "    model_torch.train()\n",
        "    x = torch.rand(2, 5)\n",
        "\n",
        "    # Feed the same batch through both implementations.\n",
        "    sgd_step(model_lib, optimizer_lib, x)\n",
        "    sgd_step(model_torch, optimizer_torch, x)\n",
        "\n",
        "    # Element-wise comparison of the LoRA gradients produced by this step.\n",
        "    # To debug a mismatch, print v.grad for each (k, v) in model.named_parameters().\n",
        "    print(torch.isclose(model_lib.lora_A.grad, model_torch.w_lora_A.grad))\n",
        "    print(torch.isclose(model_lib.lora_B.grad, model_torch.w_lora_B.grad))\n",
        "\n",
        "    # Compare the eval-mode outputs on fresh inputs; no autograd graph is needed here.\n",
        "    x_test = torch.rand(3, 5)\n",
        "    model_lib.eval()\n",
        "    model_torch.eval()\n",
        "    with torch.no_grad():\n",
        "        print(torch.isclose(model_lib(x_test), model_torch(x_test)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GdmSBAenkvP1",
        "outputId": "7e186013-f7a2-4fc6-90c3-854f1f1fc2fb"
      },
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "tensor([[True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True]])\n",
            "tensor([[True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True]])\n",
            "tensor([[True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True]])\n",
            "tensor([[True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True]])\n",
            "tensor([[True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True]])\n",
            "tensor([[True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True]])\n",
            "tensor([[True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True],\n",
            "        [True, True, True, True, True]])\n",
            "tensor([[True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True],\n",
            "        [True, True, True, True]])\n",
            "tensor([[True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True],\n",
            "        [True, True, True, True, True, True]])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Ul2JBtuQmIuG"
      },
      "execution_count": 5,
      "outputs": []
    }
  ]
}